package com.sika.code.standard.db.algorithm.table;

import cn.hutool.core.util.StrUtil;
import com.google.common.collect.Lists;
import com.sika.code.exception.BusinessException;
import com.sika.code.standard.db.algorithm.BaseShardingAlgorithm;
import com.sika.code.standard.db.algorithm.ShardingValueContext;
import com.sika.code.standard.db.util.HintManagerHandler;
import lombok.AllArgsConstructor;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.shardingsphere.api.sharding.complex.ComplexKeysShardingValue;
import org.apache.shardingsphere.api.sharding.hint.HintShardingValue;
import org.apache.shardingsphere.api.sharding.standard.PreciseShardingValue;

import java.util.Collection;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;

/**
 * Base table sharding algorithm.
 *
 * <p>Routes a logic table to a single {@code database.table} target. The database name and the
 * table name are each taken first from hint values supplied through {@link HintManagerHandler};
 * when that yields a blank name, the abstract shard-key hooks
 * {@link #getGetDataSourceNameFromShardKey} / {@link #getGetTableNameFromShardKey} are used.
 *
 * <p>Name expressions can be fixed through the all-args constructor. Otherwise they are looked up
 * from the sharding item of the logic table and cached per logic table name, so one instance can
 * safely serve several logic tables.
 *
 * @author sikadai
 * @date 2021/7/4 15:17
 */
@Slf4j
@NoArgsConstructor
@AllArgsConstructor
public abstract class BaseTableShardingAlgorithm<T extends Comparable<?>> implements BaseShardingAlgorithm<T> {
    /** Database name expression fixed at construction time; when blank it is resolved per logic table. */
    private volatile String dataBaseNameExpression;
    /** Table name expression fixed at construction time; when blank it is resolved per logic table. */
    private volatile String tableNameExpression;
    /**
     * Database name expressions resolved from sharding items, keyed by logic table name.
     * Initialized final fields are excluded from the Lombok all-args constructor.
     */
    private final Map<String, String> dataBaseNameExpressionCache = new ConcurrentHashMap<>();
    /** Table name expressions resolved from sharding items, keyed by logic table name. */
    private final Map<String, String> tableNameExpressionCache = new ConcurrentHashMap<>();
    // Shared constants for subclasses (unused in this class); presumably for zero-padded shard suffixes — confirm.
    protected static final char FILLED_CHAR_ZERO = '0';
    protected static final int LENGTH_FOUR = 4;
    protected static final int LENGTH_TWO = 2;

    /**
     * Routes a complex-keys sharding request to a single {@code database.table} target.
     */
    @Override
    public Collection<String> doSharding(Collection<String> availableTargetNames, ComplexKeysShardingValue<T> shardingValue) {
        ShardingValueContext<T> shardingValueContext = buildShardingValueContext(availableTargetNames, shardingValue);
        return buildDatabaseAndTableName(shardingValueContext);
    }

    /**
     * Routes a hint sharding request to a single {@code database.table} target.
     */
    @Override
    public Collection<String> doSharding(Collection<String> availableTargetNames, HintShardingValue<T> shardingValue) {
        ShardingValueContext<T> shardingValueContext = buildShardingValueContext(availableTargetNames, shardingValue);
        return buildDatabaseAndTableName(shardingValueContext);
    }

    /**
     * Routes a precise sharding request to its {@code database.table} target.
     */
    @Override
    public String doSharding(Collection<String> availableTargetNames, PreciseShardingValue<T> shardingValue) {
        ShardingValueContext<T> shardingValueContext = buildShardingValueContext(availableTargetNames, shardingValue);
        // buildDatabaseAndTableName always returns exactly one element, so next() cannot fail here.
        return buildDatabaseAndTableName(shardingValueContext).iterator().next();
    }

    /**
     * Builds the routing target for the given context.
     *
     * @param shardingValueContext context built from the incoming sharding value
     * @return a single-element list containing {@code "<database>.<table>"}
     */
    public Collection<String> buildDatabaseAndTableName(ShardingValueContext<T> shardingValueContext) {
        String dataName = getDataBaseName(shardingValueContext);
        String tableName = getTableName(shardingValueContext);
        return Lists.newArrayList(StrUtil.join(StrUtil.DOT, dataName, tableName));
    }

    /**
     * Returns the database name.
     *
     * <p>Uses the hint-derived name when it is not blank, otherwise falls back to the name
     * computed from the shard key.
     */
    public String getDataBaseName(ShardingValueContext<T> shardingValueContext) {
        // Prefer the database name derived from the hint context.
        String dataSourceName = getDataSourceNameFromContext(shardingValueContext);
        if (StrUtil.isNotBlank(dataSourceName)) {
            return dataSourceName;
        }
        return getGetDataSourceNameFromShardKey(shardingValueContext);
    }

    /**
     * Returns the table name.
     *
     * <p>Uses the hint-derived name when it is not blank, otherwise falls back to the name
     * computed from the shard key.
     */
    public String getTableName(ShardingValueContext<T> shardingValue) {
        // Prefer the table name derived from the hint context.
        String tableName = getTableNameFromContext(shardingValue);
        if (StrUtil.isNotBlank(tableName)) {
            return tableName;
        }
        return getGetTableNameFromShardKey(shardingValue);
    }

    /**
     * Derives the database name from the database sharding values held by {@link HintManagerHandler}.
     *
     * <p>NOTE(review): the database name expression is resolved (and validated) before the hint
     * values are inspected, so this throws {@code BusinessException} when no expression can be
     * resolved even if no hint values are present.
     */
    public String getDataSourceNameFromContext(ShardingValueContext<T> shardingValueContext) {
        // Read the routing values from the hint context first.
        Collection<Comparable<?>> databaseShardingValues = HintManagerHandler.getDatabaseShardingValues();
        return getNameFromContext(getDataBaseNameExpression(shardingValueContext.getLogicTableName()), databaseShardingValues);
    }

    /**
     * Derives the table name from the table sharding values held by {@link HintManagerHandler}.
     *
     * <p>NOTE(review): the table name expression is resolved (and validated) before the hint
     * values are inspected, so this throws {@code BusinessException} when no expression can be
     * resolved even if no hint values are present.
     */
    public String getTableNameFromContext(ShardingValueContext<T> shardingValue) {
        // Read the routing values from the hint context first.
        Collection<Comparable<?>> tableShardingValues = HintManagerHandler.getTableShardingValues();
        return getNameFromContext(getTableNameExpression(shardingValue.getLogicTableName()), tableShardingValues);
    }

    /**
     * Returns the database name expression for {@code logicTableName}.
     *
     * <p>An expression supplied through the constructor wins; otherwise the expression of the logic
     * table's sharding item is used and cached per logic table.
     *
     * @throws BusinessException when the resolved expression is blank
     */
    public String getDataBaseNameExpression(String logicTableName) {
        String configured = this.dataBaseNameExpression;
        if (StrUtil.isNotBlank(configured)) {
            return configured;
        }
        return resolveExpression(dataBaseNameExpressionCache, logicTableName,
                name -> getShardingItem(name).getDataBaseNameExpression(), "数据名称表达式为空");
    }

    /**
     * Returns the table name expression for {@code logicTableName}.
     *
     * <p>An expression supplied through the constructor wins; otherwise the expression of the logic
     * table's sharding item is used and cached per logic table.
     *
     * @throws BusinessException when the resolved expression is blank
     */
    public String getTableNameExpression(String logicTableName) {
        String configured = this.tableNameExpression;
        if (StrUtil.isNotBlank(configured)) {
            return configured;
        }
        return resolveExpression(tableNameExpressionCache, logicTableName,
                name -> getShardingItem(name).getTableNameExpression(), "表名表达式为空");
    }

    /**
     * Returns the expression cached for {@code logicTableName}, resolving it with {@code resolver} on
     * a cache miss. Only non-blank results are cached, so a failed lookup is retried on the next call.
     *
     * @throws BusinessException with {@code blankMessage} when the resolved expression is blank
     */
    private static String resolveExpression(Map<String, String> cache, String logicTableName,
                                            Function<String, String> resolver, String blankMessage) {
        // ConcurrentHashMap rejects null keys, so a null table name is resolved without caching.
        if (logicTableName != null) {
            String cached = cache.get(logicTableName);
            if (cached != null) {
                return cached;
            }
        }
        // Resolve outside any map lock: getShardingItem is defined outside this class.
        String resolved = resolver.apply(logicTableName);
        if (StrUtil.isBlank(resolved)) {
            throw new BusinessException(blankMessage);
        }
        if (logicTableName != null) {
            cache.putIfAbsent(logicTableName, resolved);
        }
        return resolved;
    }


    /**
     * Returns the database name derived from the shard key (the database name, not the data source name).
     */
    public abstract String getGetDataSourceNameFromShardKey(ShardingValueContext<T> shardingValue);

    /**
     * Returns the table name derived from the shard key.
     */
    public abstract String getGetTableNameFromShardKey(ShardingValueContext<T> shardingValue);


}
